#### Make plots from cluster permutation statistics ####
# This script makes plots from the statistics produced by 5_cluster_permutation.


# Init --------------------------------------------------------------------

library(tidyverse)
library(cowplot)
source('./R/config.R')
source('./R/5_cluster_permutation_functions.R')
source('./R/6_cluster_permutation_functions.R')

# Column sets for the one-sample analysis: the per-unit columns to keep and
# the columns whose combinations define one test group.
one.sample.group_columns <- c("Area.brain", "TreatmentID", "Event")
one.sample.select_columns <- c("UnitID", "norm.time", "z.score")

# Iteration indices 1..10000.
# NOTE(review): not referenced elsewhere in this script; presumably consumed by
# the sourced cluster-permutation helpers — confirm before removing.
iterations <- seq_len(10000)

# File names (appended to output.folder plus an optional prefix) of the saved
# observed and permuted test statistics.
one.sample.observed.file <- "One.sample_observed.rds"
two.sample.observed.file <- "Two.sample_observed.rds"
one.sample.permuted.file <- "One.sample_permuted.rds"
two.sample.permuted.file <- "Two.sample_permuted.rds"

# Load and filter original data -------------------------------------------

# Pass the path positionally: read_rds()'s first argument is named `path` in
# older readr releases and `file` in newer ones.
Data_original <- read_rds(paste0(output.folder, Units.thresholded))

# prepare data for ALL treatments
# Wide table: one row per unit (within group.columns), one column per Event
# holding that unit's response Direction. Units whose name matches the regex
# "SPK..i" are labelled "Unsorted", all others "Sorted".
Filtered.data <- Data_original %>%
  mutate(Sorted = if_else(str_detect(Unit, pattern = "SPK..i"),
                          "Unsorted", "Sorted")) %>%
  select_at(c(group.columns, "Unit", "Event", "Direction", "Sorted")) %>%
  unique() %>%
  spread(key = Event, value = Direction)

# Missing unit x event cells get a placeholder so the filter below (which drops
# rows whose condition evaluates to NA) does not discard those units.
Filtered.data[is.na(Filtered.data)] <- "Not available"

# Drop units with any "Not computed" entry, return to long format over the
# peri-event columns, re-attach the full rows of the original data, keep only
# sorted units, and build a single UnitID key from File.name/SessionID/Unit.
Filtered.data <- Filtered.data %>%
  filter_all(all_vars(. != "Not computed")) %>%
  gather(key = "Event", value = "Direction", one_of(perievent.columns)) %>%
  left_join(Data_original,
            by = c(group.columns, "Unit", "Event", "Direction")) %>%
  filter(Sorted == "Sorted") %>%
  unite(UnitID, File.name, SessionID, Unit) %>%
  select(-FileID, -Sorted)

# Release the raw data; only Filtered.data is used from here on.
Data_original <- NULL

# Distinct area x treatment x unit combinations present after filtering.
Filtered.ids <- Filtered.data %>% select(Area.brain, TreatmentID, UnitID) %>% unique(.)

# heatmap <- Filtered.data %>% filter(Area.brain == "STN" & 
#                                       TreatmentID == "Baseline",
#                                       norm.time == 0) %>%
#   select(UnitID, Direction, Event) %>%
#   unique() %>%
#   mutate(Direction_binary = case_when(Direction == "Up" ~ 1,
#                                       Direction == "Down" ~ -1,
#                                       TRUE ~ 0)) %>%
#   select(-Direction) %>%
#   spread(key = "Event", value = "Direction_binary")
# 
# heatmap.mat <- heatmap %>% select(-UnitID) %>% as.matrix()
# rownames(heatmap.mat) <- heatmap$UnitID
# pheatmap(heatmap.mat)

# One sample t.test -----------------------------------------------------------------------------------------------
#### one sample t.test ####

d <- ""
# Pass the path positionally: read_rds()'s first argument is named `path` in
# older readr releases and `file` in newer ones.
# NOTE(review): One.sample.permuted is only used by commented-out exploration
# code in this file.
One.sample.observed <- read_rds(paste0(output.folder, d, one.sample.observed.file))
One.sample.permuted <- read_rds(paste0(output.folder, d, one.sample.permuted.file))

# Mean z-score trace with SD/SEM per area x treatment x event x time bin.
time.data <- Filtered.data %>% 
  group_by(Area.brain, TreatmentID, Event, norm.time) %>%
  summarise(avg = mean(z.score), N = n(), sd = sd(z.score), sem = sd/sqrt(N)) %>%
  ungroup() %>%
  rename(time = norm.time, z.score = avg)

# Convert through character so that, if `time` is stored as a factor, its
# labels (not its integer codes) become the numeric join key.
One.sample.observed$time <- as.numeric(as.character(One.sample.observed$time))

time.data <- left_join(time.data, One.sample.observed, 
                        by = c("Area.brain", "TreatmentID", "Event", "time"))

# Mean trace +/- SEM per event for STN at Baseline, with a tile strip at y = 0
# coloured by the observed T.val (NA values left unfilled).
ggplot(data = time.data  %>% filter(Area.brain == "STN" & 
                                      TreatmentID == "Baseline" & !is.na(time)), 
       aes(x = time, y = z.score)) +
  geom_vline(xintercept = 0, linetype = "dashed", color = "grey50") +
  geom_hline(yintercept = 0, color = "grey50") +
  geom_tile(aes(x = time, y = 0, fill = T.val)) + #max(z.score)* 1.3
  geom_line(color = "black") + 
  geom_ribbon(aes(ymin = z.score - sem, ymax = z.score + sem), alpha = 0.2, fill = "grey") +
  facet_wrap(~Event) + 
  theme_bw() + 
  #theme(legend.position = "none") +
  scale_fill_distiller(palette = "RdYlBu", na.value = NA)

# 
# precue <- One.sample.observed %>% filter(Area.brain == "STN" & 
#                                            TreatmentID == "Baseline" & 
#                                            Event == "Correct_Go_Pokes" &
#                                            n > 1)
# 
# precue.perm <- One.sample.permuted %>% filter(Area.brain == "STN" & 
#                                                 TreatmentID == "Baseline" & 
#                                                 Event == "Correct_Go_Pokes" &
#                                                 n > 1) %>%
#   group_by(iteration) %>% 
#   summarise(T.val.max = max(T.val), T.val.min = min(T.val),
#             sum.t.max = max(sum.t), sum.t.min = min(sum.t)) %>%
#   ungroup() %>%
#   mutate(T.val = if_else(abs(T.val.max) >= abs(T.val.min), T.val.max, T.val.min),
#          t.max = if_else(abs(sum.t.max) >= abs(sum.t.min), sum.t.max, sum.t.min))
# 
# p <- (sum(abs(5.697271) < abs(precue.perm$t.max)) + 1) / (length(precue.perm$t.max) + 1)


# One sample t.test by direction -----------------------------------------------------------------------------------
#### one sample t.test by direction ####
d <- "By_direction_"

# Pass the path positionally: read_rds()'s first argument is named `path` in
# older readr releases and `file` in newer ones.
One.sample.observed <- read_rds(paste0(output.folder, d, one.sample.observed.file))
One.sample.permuted <- read_rds(paste0(output.folder, d, one.sample.permuted.file))

# Mean z-score trace with SD/SEM per area x treatment x event x response
# Direction x time bin.
time.data <- Filtered.data %>% 
  group_by(Area.brain, TreatmentID, Event, Direction, norm.time) %>%
  summarise(avg = mean(z.score), N = n(), sd = sd(z.score), sem = sd/sqrt(N)) %>%
  ungroup() %>%
  rename(time = norm.time, z.score = avg)

# Convert through character so that, if `time` is stored as a factor, its
# labels (not its integer codes) become the numeric join key.
One.sample.observed$time <- as.numeric(as.character(One.sample.observed$time))

time.data <- left_join(time.data, One.sample.observed, 
                       by = c("Area.brain", "TreatmentID", "Event", "Direction", "time"))

# Mean trace +/- SEM for STN under BI_cmp, one panel per Direction x Event.
# The T.val tile strip sits at 1.3 x the largest plotted mean z-score.
ggplot(data = time.data  %>% filter(Area.brain == "STN" & 
                                      TreatmentID == "BI_cmp" & !is.na(time)), 
       aes(x = time, y = z.score)) +
  geom_vline(xintercept = 0, linetype = "dashed", color = "grey50") +
  geom_hline(yintercept = 0, color = "grey50") +
  geom_line(color = "black") + 
  geom_ribbon(aes(ymin = z.score - sem, ymax = z.score + sem), alpha = 0.2, fill = "grey") +
  geom_tile(aes(x = time, y = max(z.score) * 1.3, fill = T.val)) +
  facet_grid(Direction ~ Event, scales = "free_y") + 
  theme_bw() + 
  #theme(legend.position = "none") +
  scale_fill_distiller(palette = "RdYlBu", na.value = NA)


# two sample t.test -----------------------------------------------------------------------------------------------
#### two sample t.test ####

# File-name prefix handed to read_comp(). "" selects the results without the
# "By_direction_" prefix used by the by-direction sections of this script.
d <- ""

# Pairs of events compared by the two-sample tests.
List.comprisons <- list(One = c("Freezing_80_start", "Precue_Pokes"),
                        Two = c("Correct_Go_Sound_onset", "Correct_NG_Sound_onset"),
                        Three = c("Freezing_80_stop", "Precue_Pokes"),
                        Four = c("Correct_NG_Sound_onset", "FalseAlarm_Sound_onset"),
                        Five = c("Precue_Pokes", "ITI_Pokes"),
                        Six = c("Correct_Go_Sound_onset", "FalseAlarm_Sound_onset"),
                        Seven = c("Freezing_80_start", "Correct_NG_Sound_onset"))

# read_comp() comes from the sourced helper files. It is called once per event
# pair and the results are row-bound. Passing `d` (rather than a duplicate
# literal) keeps the prefix defined in one place.
Two.sample.observed <- map_df(List.comprisons, read_comp, d = d)

# Convert through character so factor labels, not codes, become numeric times.
Two.sample.observed$time <- as.numeric(as.character(Two.sample.observed$time))

# Mean z-score trace with SD/SEM per area x treatment x event x time bin.
time.data <- Filtered.data %>% 
  group_by(Area.brain, TreatmentID, Event, norm.time) %>%
  summarise(avg = mean(z.score), N = n(), sd = sd(z.score), sem = sd/sqrt(N)) %>%
  ungroup() %>%
  rename(time = norm.time, z.score = avg)

# Restrict the two-sample plots to one brain area and treatment.
brain <- "STN"
treat <- "BI_cmp"

# For each event pair, combine:
#   - the mean traces of both events (from time.data);
#   - the observed two-sample statistics for that pair, joined on time.
# Each result is labelled "<event 1> vs <event 2>" in `comparison`, and all
# pairs are row-bound.
# (Previously a stray top-level paste0(comparison, ...) referenced `comparison`
# before it existed, and the loop grew time.merge with bind_rows.)
time.merge <- map_df(List.comprisons, function(comparison) {
  file.temp <- paste0(comparison, collapse = " vs ")

  time.temp <- time.data %>%
    filter(Area.brain == brain &
             TreatmentID == treat &
             !is.na(time) &
             Event %in% comparison)

  # `.data$comparison` is the column of Two.sample.observed, not this
  # function's argument of the same name.
  t.temp <- Two.sample.observed %>%
    filter(Area.brain == brain &
             TreatmentID == treat &
             !is.na(time) &
             .data$comparison == file.temp) %>%
    select(-comparison, -Area.brain, -TreatmentID)

  time.temp <- left_join(time.temp, t.temp, by = "time")
  time.temp$comparison <- file.temp
  time.temp
})


# One panel per event pair: both events' mean traces overlaid, over a tile
# strip at y = 0 coloured by the observed two-sample T.val.
time.merge %>%
  ggplot(mapping = aes(x = time, y = z.score)) +
  geom_vline(xintercept = 0, linetype = "dashed", color = "grey50") +
  geom_hline(yintercept = 0, color = "grey50") +
  # x = time is inherited from the plot mapping.
  # Alternative strip height tried: y = max(z.score) * 1.3
  geom_tile(mapping = aes(y = 0, fill = T.val), alpha = 0.7) +
  geom_line(mapping = aes(color = Event), size = 1.2) +
  # geom_ribbon(aes(ymin = z.score - sem, ymax = z.score + sem, color = Event), alpha = 0.2, size = 1) +
  facet_wrap(facets = ~ comparison, scales = "free_y") +
  theme_bw() +
  # theme(legend.position = "none") +
  scale_fill_distiller(palette = "RdYlBu", na.value = NA)





# Observed two-sample statistics for STN at Baseline, restricted to rows with
# n > 1 (presumably clusters spanning more than one time bin — confirm).
precue <- Two.sample.observed %>% filter(Area.brain == "STN" & 
                                           TreatmentID == "Baseline" &
                                           n > 1)

# NOTE(review): nothing else in this script creates Two.sample.permuted, so the
# original code failed here on a fresh run. Load it from the declared
# two.sample.permuted.file, the same way the one-sample permutations are
# loaded, unless it is already in memory. The single-file layout is an
# assumption; confirm against the script that writes it.
if (!exists("Two.sample.permuted")) {
  Two.sample.permuted <- read_rds(paste0(output.folder, d, two.sample.permuted.file))
}

# Per permutation iteration: the extreme T.val and cluster sum (sum.t). Keep
# whichever of max/min has the larger magnitude (ties go to the max).
precue.perm <- Two.sample.permuted %>% filter(Area.brain == "STN" & 
                                                TreatmentID == "Baseline" &
                                                n > 1) %>%
  group_by(iteration) %>% 
  summarise(T.val.max = max(T.val), T.val.min = min(T.val),
            sum.t.max = max(sum.t), sum.t.min = min(sum.t)) %>%
  ungroup() %>%
  mutate(T.val = if_else(abs(T.val.max) >= abs(T.val.min), T.val.max, T.val.min),
         t.max = if_else(abs(sum.t.max) >= abs(sum.t.min), sum.t.max, sum.t.min))

# Observed cluster statistic, hard-coded in the original analysis.
# NOTE(review): presumably copied from `precue`; derive it from there once the
# intended comparison/cluster is pinned down.
precue.observed.t <- -7.95886

# Permutation p-value with the +1 correction: the share of permuted maxima
# whose magnitude exceeds the observed one. Printed at top level.
precue.p <- (sum(abs(precue.observed.t) < abs(precue.perm$t.max)) + 1) /
  (length(precue.perm$t.max) + 1)
precue.p

# By direction

# Observed two-sample statistics split by response direction, one read_comp()
# call per event pair, row-bound together.
# NOTE(review): this prefix is "By_direction" while the one-sample section uses
# "By_direction_" (trailing underscore). Confirm which one read_comp() expects.
Two.sample.observed <- map_df(List.comprisons, read_comp, d = "By_direction")

# Convert through character so factor labels, not codes, become numeric times.
Two.sample.observed$time <- as.numeric(as.character(Two.sample.observed$time))


# Mean z-score trace with SD/SEM per area x treatment x event x time bin x
# response Direction. The time column is renamed while grouping.
time.data <- Filtered.data %>%
  group_by(Area.brain, TreatmentID, Event, time = norm.time, Direction) %>%
  summarise(avg = mean(z.score),
            N = n(),
            sd = sd(z.score),
            sem = sd / sqrt(N)) %>%
  ungroup() %>%
  rename(z.score = avg)

# Restrict the by-direction two-sample plots to one brain area and treatment.
brain <- "STN"
treat <- "BI_cmp"

# For each event pair, combine:
#   - the by-direction mean traces of both events;
#   - the observed by-direction two-sample statistics for that pair, joined on
#     time and Direction.
# Each result is labelled "<event 1> vs <event 2>" in `comparison`, and all
# pairs are row-bound. map_df() replaces growing time.merge with bind_rows()
# inside a for loop.
time.merge <- map_df(List.comprisons, function(comparison) {
  file.temp <- paste0(comparison, collapse = " vs ")

  time.temp <- time.data %>%
    filter(Area.brain == brain &
             TreatmentID == treat &
             !is.na(time) &
             Event %in% comparison)

  # `.data$comparison` is the column of Two.sample.observed, not this
  # function's argument of the same name.
  t.temp <- Two.sample.observed %>%
    filter(Area.brain == brain &
             TreatmentID == treat &
             !is.na(time) &
             .data$comparison == file.temp) %>%
    select(-comparison, -Area.brain, -TreatmentID)

  time.temp <- left_join(time.temp, t.temp, by = c("time", "Direction"))
  time.temp$comparison <- file.temp
  time.temp
})


# Grid of panels (rows: response Direction, columns: event pair). Each panel
# shows SEM ribbons outlined by event, plus a T.val tile strip at 1.3 x the
# largest plotted mean z-score.
time.merge %>%
  ggplot(mapping = aes(x = time, y = z.score)) +
  geom_vline(xintercept = 0, linetype = "dashed", color = "grey50") +
  geom_hline(yintercept = 0, color = "grey50") +
  # geom_line(aes(color = Event)) +
  geom_ribbon(mapping = aes(ymin = z.score - sem, ymax = z.score + sem, color = Event),
              alpha = 0.2, size = 1) +
  # x = time is inherited from the plot mapping.
  geom_tile(mapping = aes(y = max(z.score) * 1.3, fill = T.val)) +
  facet_grid(Direction ~ comparison, scales = "free_y") +
  theme_bw() +
  # theme(legend.position = "none") +
  scale_fill_distiller(palette = "RdYlBu", na.value = NA)
